{
 "cells": [
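  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "State-value function:\n",
    "\n",
    "V(state) = sum over all actions -> prob(action) * Q(state,action)\n",
    "\n",
    "Multiplying and dividing by the old probability, i.e. the one in effect when the data was collected, rewrites this as:\n",
    "\n",
    "V(state) = sum over all actions -> old_prob(action) * \\[current_prob(action) / old_prob(action)\\] * Q(state,action)\n",
    "\n",
    "Here prob(action) in the first formula is the current probability. Initially the current and old probabilities can be treated as equal, but the current probability changes as the model is updated.\n",
    "\n",
    "Q(state,action) in this formula can be estimated with the Monte Carlo method.\n",
    "\n",
    "According to policy gradient theory, the state value depends on the quality of the actions, so maximizing the V function yields the best action policy."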
   ]
  },
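  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch in standard notation (the symbols $\\pi_{\\text{old}}$ and $\\pi_{\\text{new}}$ are introduced here for the old and current action probabilities; they are not names used elsewhere in this notebook), weighting each action by the ratio of current to old probability lets actions sampled under the old probabilities estimate the value under the current ones:\n",
    "\n",
    "$$V(s) = \\sum_a \\pi_{\\text{new}}(a \\mid s)\\, Q(s,a) = \\sum_a \\pi_{\\text{old}}(a \\mid s)\\, \\frac{\\pi_{\\text{new}}(a \\mid s)}{\\pi_{\\text{old}}(a \\mid s)}\\, Q(s,a)$$\n",
    "\n",
    "The fraction is the quantity computed as `ratio = prob_new / prob_old` in `train_action`. It is well defined here because the softmax output keeps every action probability strictly positive."
   ]
  },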
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADMCAYAAADTcn7NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAT70lEQVR4nO3db2xT97kH8K+dPwaSHIcAsZsmFkhFpRF/1gYIZ32x3uKR0Wgaa16sE+qyCoHKHFSaCWmRKAjUqyAm3a7dIEh3GvRNR5dKdGqUtspCG7SLISVtrkKAaL1qlajk2KXIx05GnMR+7gtuzq1LYDh//LPx9yMdCZ/fY/s5x/ibc87PTmwiIiAiUsCuugEiyl4MICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUkZZAB07dgzLly/HggULUF1dje7ublWtEJEiSgLo7bffRmNjIw4ePIhPP/0U69atQ01NDYLBoIp2iEgRm4ovo1ZXV2PDhg34wx/+AACIx+OoqKjAnj178Jvf/CbV7RCRIrmpfsLx8XH09PSgqanJWme32+H1euH3+6e9TzQaRTQatW7H43HcvHkTS5Ysgc1mm/eeiSg5IoJIJIKysjLY7Xc/0Up5AN24cQOxWAwulythvcvlwrVr16a9T3NzMw4dOpSK9ohoDg0NDaG8vPyu4ykPoJloampCY2Ojdds0TXg8HgwNDUHTNIWdEdF0wuEwKioqUFRUdM+6lAfQ0qVLkZOTg0AgkLA+EAjA7XZPex+HwwGHw3HHek3TGEBEaexfXSJJ+SxYfn4+qqqq0NnZaa2Lx+Po7OyEruupboeIFFJyCtbY2Ij6+nqsX78eGzduxO9+9zuMjo7ihRdeUNEOESmiJIB+9rOf4euvv8aBAwdgGAa+973v4YMPPrjjwjQRPdiUfA5otsLhMJxOJ0zT5DUgojR0v+9RfheMiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyiQdQOfOncOPf/xjlJWVwWaz4d13300YFxEcOHAADz30EBYuXAiv14t//OMfCTU3b97E9u3boWkaiouLsWPHDoyMjMxqQ4go8yQdQKOjo1i3bh2OHTs27fjRo0fxxhtv4MSJE7h48SIKCgpQU1ODsbExq2b79u3o7+9HR0cH2tracO7cOezatWvmW0FEmUlmAYCcOXPGuh2Px8Xtdstvf/tba10oFBKHwyF//vOfRUTkypUrAkA++eQTq+b9998Xm80mX3311X09r2maAkBM05xN+0Q0T+73PTqn14C++OILGIYBr9drrXM6naiurobf7wcA+P1+FBcXY/369VaN1+uF3W7HxYsXp33caDSKcDicsBBR5pvTADIMAwDgcrkS1rtcLmvMMAyUlpYmjOfm5qKkpMSq+a7m5mY4nU5rqaiomMu2iUiRjJgFa2pqgmma1jI0NKS6JSKaA3MaQG63GwAQCAQS1gcCAWvM7XYjGAwmjE9OTuLmzZtWzXc5HA5ompawEFHmm9MAWrFiBdxuNzo7O6114XAYFy9ehK7rAABd1xEKhdDT02PVnD17FvF4HNXV1XPZDhGludxk7zAyMoLPP//cuv3FF1+gt7cXJSUl8Hg82Lt3L1599VWsXLkSK1aswCuvvIKysjJs27YNAPDYY4/hRz/6EXbu3IkTJ05gYmICDQ0NeO6551BWVjZnG0ZEGSDZ6bWPPvpIANyx1NfXi8jtqfhXXnlFXC6XOBwO2bx5swwMDCQ8xjfffCM///nPpbCwUDRNkxdeeEEikcicT/ERkRr3+x61iYgozL8ZCYfDcDqdME2T14OI0tD9vkczYhaMiB5M
DCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhImaT/LA/RXBMR3Lr5FW7d/Mpal7uwCNrDj8FmsynsjOYbA4jSws3/uYThz9qt24XuldAefkxhR5QKPAWjtCDxmOoWSAEGEKUFkbjqFkgBBhClhzgDKBsxgCgNCI+AshQDiNKCCK8BZSMGEKkngPAULCslFUDNzc3YsGEDioqKUFpaim3btmFgYCChZmxsDD6fD0uWLEFhYSHq6uoQCAQSagYHB1FbW4tFixahtLQU+/btw+Tk5Oy3hjIXT8GyUlIB1NXVBZ/PhwsXLqCjowMTExPYsmULRkdHrZqXX34Z7733HlpbW9HV1YXr16/j2WeftcZjsRhqa2sxPj6O8+fP480338SpU6dw4MCBudsqyjDCI6BsJbMQDAYFgHR1dYmISCgUkry8PGltbbVqrl69KgDE7/eLiEh7e7vY7XYxDMOqaWlpEU3TJBqN3tfzmqYpAMQ0zdm0T2kiHpuUz//2n9J9Yqe1XHn3qMTjcdWt0Qzd73t0VteATNMEAJSUlAAAenp6MDExAa/Xa9WsWrUKHo8Hfr8fAOD3+7FmzRq4XC6rpqamBuFwGP39/dM+TzQaRTgcTljowcIPImanGQdQPB7H3r178eSTT2L16tUAAMMwkJ+fj+Li4oRal8sFwzCsmm+Hz9T41Nh0mpub4XQ6raWiomKmbVMaEhFILDGAbHbOj2SDGb/KPp8Ply9fxunTp+eyn2k1NTXBNE1rGRoamvfnpBSSOMbMxB8+CxeXKWqGUmlGX0ZtaGhAW1sbzp07h/Lycmu92+3G+Pg4QqFQwlFQIBCA2+22arq7uxMeb2qWbKrmuxwOBxwOx0xapQwgInecgtlz8xR1Q6mU1BGQiKChoQFnzpzB2bNnsWLFioTxqqoq5OXlobOz01o3MDCAwcFB6LoOANB1HX19fQgGg1ZNR0cHNE1DZWXlbLaFHiQ2noJlg6SOgHw+H9566y389a9/RVFRkXXNxul0YuHChXA6ndixYwcaGxtRUlICTdOwZ88e6LqOTZs2AQC2bNmCyspKPP/88zh69CgMw8D+/fvh8/l4lEMWmz1HdQuUAkkFUEtLCwDgqaeeSlh/8uRJ/PKXvwQAvPbaa7Db7airq0M0GkVNTQ2OHz9u1ebk5KCtrQ27d++GrusoKChAfX09Dh8+PLstoQcKL0JnB5uIiOomkhUOh+F0OmGaJjRNU90OzVJsIor+dw4jGv7aWvfwhm146PGt/I2IGep+36P8MUNpiUdA2YGvMqUlm43XgLIBA4jSEo+AsgNfZUpPDKCswFeZ0hJPwbIDA4jSEk/BsgNfZUpLNpudU/BZgAFE6kkcGfhxNJoDDCBS7nb4MICyEQOI1JM4wCOgrMQAIuWEAZS1GECknIhAeAqWlRhApB6PgLIWA4iU4wxY9mIAkXpxTsNnKwYQKSciPAXLUgwgUk/i4OeAshMDiJQTid2eiv82fg0jKzCASLnxkRBi0VvWbXtuPhxFSxV2RKnCACLlRGJIOAWz2WDLmdGfrKMMwwCiNGSDjX8XLCvwVaa0xADKDnyVKe3YbDb+StYswVeZ0hBPwbIFr/TRvIvH44hEInf9tPPo6GjCbRFBZGQEUVto2vrc3FwUFBTwNyY+ABhANO8Mw8DmzZsRiUSmHX/ikWXY/9wG67Zpmti19RkEQremrdd1HX/5y1/mpVdKLQYQzbtYLIbh4WGYpjntuKfYjhsTD+PLW2uQbx9Dafy/cH3YgHFzZNr6GzduzGe7lEJJnWi3tLRg7dq10DQNmqZB13W8//771vjY2Bh8Ph+WLFmCwsJC1NXVIRAIJDzG4OAgamtrsWjRIpSWlmLfvn2YnJycm62hjBSZXIL/jjyNGxMeXI+uRG/k3zAe48/GbJBUAJWXl+PI
kSPo6enBpUuX8PTTT+MnP/kJ+vv7AQAvv/wy3nvvPbS2tqKrqwvXr1/Hs88+a90/FouhtrYW4+PjOH/+PN58802cOnUKBw4cmNutoowSlUUYjy/4v1s2jMY0TMR5fScryCwtXrxY/vjHP0ooFJK8vDxpbW21xq5evSoAxO/3i4hIe3u72O12MQzDqmlpaRFN0yQajd73c5qmKQDENM3Ztk8pMDg4KE6nc+o3z9+xPPHYo/L6kTY59Gq3HH71gvzHwePiLCy4a/1TTz0l8Xhc9WbRPdzve3TGx7mxWAytra0YHR2Fruvo6enBxMQEvF6vVbNq1Sp4PB74/X5s2rQJfr8fa9asgcvlsmpqamqwe/du9Pf34/HHH0+qh2vXrqGwsHCmm0ApYhgGYrHYXccDX1/H+c5/R3B8OfJsUWj4HGPRsbvWj46O4sqVK5wFS2MjI9Nfv/uupAOor68Puq5jbGwMhYWFOHPmDCorK9Hb24v8/HwUFxcn1LtcLhiGAeD2f8Rvh8/U+NTY3USjUUSjUet2OBwGcHu2hNeP0t+9puAB4KsbEbzd4Qfgv6/Hm5ycvOsFbUoP3/1oxd0kHUCPPvooent7YZom3nnnHdTX16OrqyvpBpPR3NyMQ4cO3bG+uroamqbN63PT7A0NDSE3d+4uKjudTui6ziOgNDZ1kPCvJP1x0/z8fDzyyCOoqqpCc3Mz1q1bh9dffx1utxvj4+MIhUIJ9YFAAG63GwDgdrvvmBWbuj1VM52mpiaYpmktQ0NDybZNRGlo1p93j8fjiEajqKqqQl5eHjo7O62xgYEBDA4OQtd1ALc/QNbX14dgMGjVdHR0QNM0VFZW3vU5HA6HNfU/tRBR5kvquLipqQlbt26Fx+NBJBLBW2+9hY8//hgffvghnE4nduzYgcbGRpSUlEDTNOzZswe6rmPTpk0AgC1btqCyshLPP/88jh49CsMwsH//fvh8PjgcjnnZQCJKX0kFUDAYxC9+8QsMDw/D6XRi7dq1+PDDD/HDH/4QAPDaa6/Bbrejrq4O0WgUNTU1OH78uHX/nJwctLW1Yffu3dB1HQUFBaivr8fhw4fndqsordjtdhQVFc3ZX74oKCiYk8ch9WwyV/8rUigcDsPpdMI0TZ6OZYDJyUkYhjFnAeRwOLBs2TJehE5j9/se5efdad7l5uaivLxcdRuUhvhLV4hIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyuaobmAkRAQCEw2HFnRDRdKbem1Pv1bvJyAD65ptvAAAVFRWKOyGie4lEInA6nXcdz8gAKikpAQAMDg7ec+MoUTgcRkVFBYaGhqBpmup2MgL32cyICCKRCMrKyu5Zl5EBZLffvnTldDr5n2IGNE3jfksS91ny7ufggBehiUgZBhARKZORAeRwOHDw4EE4HA7VrWQU7rfkcZ/NL5v8q3kyIqJ5kpFHQET0YGAAEZEyDCAiUoYBRETKZGQAHTt2DMuXL8eCBQtQXV2N7u5u1S0p09zcjA0bNqCoqAilpaXYtm0bBgYGEmrGxsbg8/mwZMkSFBYWoq6uDoFAIKFmcHAQtbW1WLRoEUpLS7Fv3z5MTk6mclOUOXLkCGw2G/bu3Wut4z5LEckwp0+flvz8fPnTn/4k/f39snPnTikuLpZAIKC6NSVqamrk5MmTcvnyZent7ZVnnnlGPB6PjIyMWDUvvviiVFRUSGdnp1y6dEk2bdok3//+963xyclJWb16tXi9Xvnss8+kvb1dli5dKk1NTSo2KaW6u7tl+fLlsnbtWnnppZes9dxnqZFxAbRx40bx+XzW7VgsJmVlZdLc3Kywq/QRDAYFgHR1dYmISCgUkry8
PGltbbVqrl69KgDE7/eLiEh7e7vY7XYxDMOqaWlpEU3TJBqNpnYDUigSicjKlSulo6NDfvCDH1gBxH2WOhl1CjY+Po6enh54vV5rnd1uh9frhd/vV9hZ+jBNE8D/f2G3p6cHExMTCfts1apV8Hg81j7z+/1Ys2YNXC6XVVNTU4NwOIz+/v4Udp9aPp8PtbW1CfsG4D5LpYz6MuqNGzcQi8USXnQAcLlcuHbtmqKu0kc8HsfevXvx5JNPYvXq1QAAwzCQn5+P4uLihFqXywXDMKya6fbp1NiD6PTp0/j000/xySef3DHGfZY6GRVAdG8+nw+XL1/G3//+d9WtpLWhoSG89NJL6OjowIIFC1S3k9Uy6hRs6dKlyMnJuWM2IhAIwO12K+oqPTQ0NKCtrQ0fffQRysvLrfVutxvj4+MIhUIJ9d/eZ263e9p9OjX2oOnp6UEwGMQTTzyB3Nxc5ObmoqurC2+88QZyc3Phcrm4z1IkowIoPz8fVVVV6OzstNbF43F0dnZC13WFnakjImhoaMCZM2dw9uxZrFixImG8qqoKeXl5CftsYGAAg4OD1j7TdR19fX0IBoNWTUdHBzRNQ2VlZWo2JIU2b96Mvr4+9Pb2Wsv69euxfft269/cZymi+ip4sk6fPi0Oh0NOnTolV65ckV27dklxcXHCbEQ22b17tzidTvn4449leHjYWv75z39aNS+++KJ4PB45e/asXLp0SXRdF13XrfGpKeUtW7ZIb2+vfPDBB7Js2bKsmlL+9iyYCPdZqmRcAImI/P73vxePxyP5+fmyceNGuXDhguqWlAEw7XLy5Emr5tatW/KrX/1KFi9eLIsWLZKf/vSnMjw8nPA4X375pWzdulUWLlwoS5culV//+tcyMTGR4q1R57sBxH2WGvx1HESkTEZdAyKiBwsDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUuZ/ATLnxgb/cBGNAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Define the environment\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('CartPole-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        over = terminated or truncated\n",
    "\n",
    "        # Cap each episode at 200 steps\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            over = True\n",
    "\n",
    "        # Penalize episodes that end before step 200\n",
    "        if over and self.step_n < 200:\n",
    "            reward = -1000\n",
    "\n",
    "        return state, reward, over\n",
    "\n",
    "    # Render the current game frame\n",
    "    def show(self):\n",
    "        from matplotlib import pyplot as plt\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(self.env.render())\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()\n",
    "\n",
    "env.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.5028, 0.4972],\n",
       "         [0.5123, 0.4877]], grad_fn=<SoftmaxBackward0>),\n",
       " tensor([[-0.0145],\n",
       "         [ 0.0464]], grad_fn=<AddmmBackward0>))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "#定义模型\n",
    "model_action = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 2),\n",
    "    torch.nn.Softmax(dim=1),\n",
    ")\n",
    "\n",
    "model_value = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "model_action(torch.randn(2, 4)), model_value(torch.randn(2, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_7761/1112667714.py:34: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  state = torch.FloatTensor(state).reshape(-1, 4)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "-979.0"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Play one episode and record every transition\n",
    "def play(show=False):\n",
    "    state = []\n",
    "    action = []\n",
    "    reward = []\n",
    "    next_state = []\n",
    "    over = []\n",
    "\n",
    "    s = env.reset()\n",
    "    o = False\n",
    "    while not o:\n",
    "        # Sample an action according to the policy's probabilities\n",
    "        prob = model_action(torch.FloatTensor(s).reshape(1, 4))[0].tolist()\n",
    "        a = random.choices(range(2), weights=prob, k=1)[0]\n",
    "\n",
    "        ns, r, o = env.step(a)\n",
    "\n",
    "        state.append(s)\n",
    "        action.append(a)\n",
    "        reward.append(r)\n",
    "        next_state.append(ns)\n",
    "        over.append(o)\n",
    "\n",
    "        s = ns\n",
    "\n",
    "        if show:\n",
    "            display.clear_output(wait=True)\n",
    "            env.show()\n",
    "\n",
    "    # Stack the per-step arrays into one ndarray first; building a tensor\n",
    "    # directly from a list of ndarrays is slow and triggers a UserWarning\n",
    "    state = torch.FloatTensor(np.array(state)).reshape(-1, 4)\n",
    "    action = torch.LongTensor(action).reshape(-1, 1)\n",
    "    reward = torch.FloatTensor(reward).reshape(-1, 1)\n",
    "    next_state = torch.FloatTensor(np.array(next_state)).reshape(-1, 4)\n",
    "    over = torch.LongTensor(over).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over, reward.sum().item()\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over, reward_sum = play()\n",
    "\n",
    "reward_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer_action = torch.optim.Adam(model_action.parameters(), lr=1e-3)\n",
    "optimizer_value = torch.optim.Adam(model_value.parameters(), lr=1e-2)\n",
    "\n",
    "\n",
    "def requires_grad(model, value):\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad_(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([22, 1])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_value(state, reward, next_state, over):\n",
    "    requires_grad(model_action, False)\n",
    "    requires_grad(model_value, True)\n",
    "\n",
    "    # Compute the TD target: reward plus the discounted value of the next\n",
    "    # state, with the next-state term dropped once the episode is over\n",
    "    with torch.no_grad():\n",
    "        target = model_value(next_state)\n",
    "    target = target * 0.98 * (1 - over) + reward\n",
    "\n",
    "    # Train on each batch 10 times\n",
    "    for _ in range(10):\n",
    "        # Predict the value of each state\n",
    "        value = model_value(state)\n",
    "\n",
    "        loss = torch.nn.functional.mse_loss(value, target)\n",
    "        loss.backward()\n",
    "        optimizer_value.step()\n",
    "        optimizer_value.zero_grad()\n",
    "\n",
    "    # Subtracting value acts as a baseline; the result is the TD error\n",
    "    return (target - value).detach()\n",
    "\n",
    "\n",
    "value = train_value(state, reward, next_state, over)\n",
    "\n",
    "value.shape"
   ]
  },
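  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The target and the return value of `train_value` can be written compactly. This is a sketch in notation introduced here, not taken from the notebook: $\\gamma = 0.98$ is the constant in the code, $d_t$ stands for the `over` flag, $r_t$ for `reward`, and $V$ for `model_value`:\n",
    "\n",
    "$$y_t = r_t + \\gamma\\,(1 - d_t)\\,V(s_{t+1}), \\qquad \\delta_t = y_t - V(s_t)$$\n",
    "\n",
    "The factor $(1 - d_t)$ removes the bootstrap term on the step where `over` is set, so the target there is just the reward. The function returns $\\delta_t$ using the $V(s_t)$ predicted in the last of the 10 training passes."
   ]
  },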
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "481.4684753417969"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_action(state, action, value):\n",
    "    requires_grad(model_action, True)\n",
    "    requires_grad(model_value, False)\n",
    "\n",
    "    # Estimate how good each taken action was, standing in for Q(state, action).\n",
    "    # The estimate comes from the sampled episode: a discounted sum of the\n",
    "    # per-step errors from train_value, with decay 0.98 * 0.95 per step\n",
    "    delta = []\n",
    "    for i in range(len(value)):\n",
    "        s = 0\n",
    "        for j in range(i, len(value)):\n",
    "            s += value[j] * (0.98 * 0.95)**(j - i)\n",
    "        delta.append(s)\n",
    "    delta = torch.FloatTensor(delta).reshape(-1, 1)\n",
    "\n",
    "    # Probabilities of the taken actions before the update\n",
    "    with torch.no_grad():\n",
    "        prob_old = model_action(state).gather(dim=1, index=action)\n",
    "\n",
    "    # Train on each batch 10 times\n",
    "    for _ in range(10):\n",
    "        # Probabilities of the taken actions under the current parameters\n",
    "        prob_new = model_action(state).gather(dim=1, index=action)\n",
    "\n",
    "        # Ratio of current to old probability\n",
    "        ratio = prob_new / prob_old\n",
    "\n",
    "        # Compute the unclipped and clipped objectives and keep the\n",
    "        # elementwise smaller of the two\n",
    "        surr1 = ratio * delta\n",
    "        surr2 = ratio.clamp(0.8, 1.2) * delta\n",
    "\n",
    "        loss = -torch.min(surr1, surr2).mean()\n",
    "\n",
    "        # Update the parameters\n",
    "        loss.backward()\n",
    "        optimizer_action.step()\n",
    "        optimizer_action.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_action(state, action, value)"
   ]
  },
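  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two computations in `train_action` can be summarized as follows. This is a sketch in notation introduced here, not taken from the notebook: $\\delta_j$ is the per-step value returned by `train_value`, $\\gamma\\lambda = 0.98 \\times 0.95$ is the decay used in the nested loop, $\\rho_t$ is `ratio`, $T$ is the number of steps in the episode, and 0.8 and 1.2 are the bounds passed to `clamp`:\n",
    "\n",
    "$$\\hat{A}_t = \\sum_{j=t}^{T-1} (\\gamma\\lambda)^{j-t}\\,\\delta_j, \\qquad L = -\\frac{1}{T}\\sum_{t=0}^{T-1} \\min\\bigl(\\rho_t\\,\\hat{A}_t,\\ \\mathrm{clip}(\\rho_t,\\,0.8,\\,1.2)\\,\\hat{A}_t\\bigr)$$\n",
    "\n",
    "This has the form of the clipped surrogate loss from PPO, with $\\hat{A}_t$ computed in the style of generalized advantage estimation. The notebook does not name either technique, so that mapping is an interpretation rather than something the code states."
   ]
  },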
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 57.404075622558594 -978.6\n",
      "100 -18.518997192382812 200.0\n",
      "200 0.9326910972595215 200.0\n",
      "300 -0.06682347506284714 200.0\n",
      "400 -0.040946170687675476 200.0\n",
      "500 -0.3048292398452759 200.0\n",
      "600 -0.3876965045928955 200.0\n",
      "700 2.3514397144317627 200.0\n",
      "800 1.754111886024475 200.0\n",
      "900 25.258014678955078 200.0\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    model_action.train()\n",
    "    model_value.train()\n",
    "\n",
    "    # Train for 1000 epochs\n",
    "    for epoch in range(1000):\n",
    "        # Play at least 200 steps per epoch\n",
    "        steps = 0\n",
    "        while steps < 200:\n",
    "            state, action, reward, next_state, over, _ = play()\n",
    "            steps += len(state)\n",
    "\n",
    "            # Train both models on the episode just played\n",
    "            delta = train_value(state, reward, next_state, over)\n",
    "            loss = train_action(state, action, delta)\n",
    "\n",
    "        if epoch % 100 == 0:\n",
    "            test_result = sum([play()[-1] for _ in range(20)]) / 20\n",
    "            print(epoch, loss, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADMCAYAAADTcn7NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAT7ElEQVR4nO3df0xb570/8LcN2IHAMYEUeyi4ye2qZig/1pGEnFX6bt/WDWu507Jwpa2KWlZF6W1ioqRMkYbUpkpuJaLsj2zdUvLH7pJKV10mJrGpiLZi0JJb1QmJM64ITdCimwqUxHZDxDHQYgz+3D8izuoGMgzEj23eL+lI8fM8Pv6cB593zg+MLSIiICJSwKq6ACJauhhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkjLIAOnHiBFavXo1ly5ahsrIS3d3dqkohIkWUBNAf//hH1NfX4/XXX8elS5ewceNGVFVVIRQKqSiHiBSxqPgwamVlJTZv3ozf/va3AIBYLIaysjLs27cPv/jFL5JdDhEpkp3sF5yYmIDf70dDQ4PZZrVa4fF44PP5ZnxOJBJBJBIxH8diMdy5cwfFxcWwWCwPvGYiSoyIYGRkBKWlpbBaZz/RSnoA3b59G1NTU3A6nXHtTqcTV69enfE5jY2NOHz4cDLKI6JFNDg4iFWrVs3an/QAmo+GhgbU19ebjw3DgNvtxuDgIDRNU1gZEc0kHA6jrKwMBQUF9x2X9ABauXIlsrKyEAwG49qDwSBcLteMz7Hb7bDb7fe0a5rGACJKYf/sEknS74LZbDZUVFSgo6PDbIvFYujo6ICu68kuh4gUUnIKVl9fj9raWmzatAlbtmzBr371K4yNjeHFF19UUQ4RKaIkgH7yk5/g888/x6FDhxAIBPDtb38b77///j0Xpokosyn5PaCFCofDcDgcMAyD14CIUtBc91F+FoyIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKJBxAZ8+exQ9/+EOUlpbCYrHgz3/+c1y/iODQoUP4xje+gdzcXHg8Hvz973+PG3Pnzh3s3LkTmqahsLAQu3btwujo6II2hIjST8IBNDY2ho0bN+LEiRMz9h87dgxvvvkmTp48ifPnz2P58uWoqqrC+Pi4OWbnzp3o6+tDe3s7WltbcfbsWbz00kvz3woiSk+yAACkpaXFfByLxcTlcskvf/lLs214eFjsdrv84Q9/EBGRTz/9VADIhQsXzDHvvfeeWCwWuXHjxpxe1zAMASCGYSykfCJ6QOa6jy7qNaDr168jEAjA4/GYbQ6HA5WVlfD5fAAAn8+HwsJCbNq0yRzj8XhgtVpx/vz5GdcbiUQQDofjFiJKf4saQIFAAADgdDrj2p1Op9kXCARQUlIS15+dnY2ioiJzzNc1NjbC4XCYS1lZ2WKWTUSKpMVdsIaGBhiGYS6Dg4OqSyKiRbCoAeRyuQAAwWAwrj0YDJp9LpcLoVAorn9ychJ37twxx3yd3W6HpmlxCxGlv0UNoDVr1sDlcqGjo8NsC4fDOH/+PHRdBwDouo7h4WH4/X5zTGdnJ2KxGCorKxezHCJKcdmJPmF0dBTXrl0zH1+/fh09PT0oKiqC2+3GgQMH8MYbb+DRRx/FmjVr8Nprr6G0tBTbt28HAHzrW9/CD37wA+zevRsnT55ENBpFXV0dfvrTn6K0tHTRNoyI0kCit9c+/PBDAXDPUltbKyJ3b8W/9tpr4nQ6xW63y1NPPSX9/f1x6xgaGpLnnntO8vPzRdM0efHFF2VkZGTRb/ERkRpz3UctIiIK829ewuEwHA4HDMPg9SCiFDTXfTQt7oIR
UWZiABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKJPy1PJTaIiNDGA3842uTrNk2ONzrYc3ij5pSD9+VGWYsdB3/2/mf5uOc5SuwrvQxBhClJJ6CZTyBSEx1EUQzYgBlOhGAAUQpigG0BEiMAUSpiQGU4UTk7lEQUQpiAGU8XgOi1MUAynTCAKLUlVAANTY2YvPmzSgoKEBJSQm2b9+O/v7+uDHj4+Pwer0oLi5Gfn4+ampqEAwG48YMDAyguroaeXl5KCkpwcGDBzE5ObnwraGZMYAoRSUUQF1dXfB6vTh37hza29sRjUaxbds2jI2NmWNeeeUVvPvuu2hubkZXVxdu3ryJHTt2mP1TU1Oorq7GxMQEPvnkE7z99ts4ffo0Dh06tHhbRSYRuXsdiCgVyQKEQiEBIF1dXSIiMjw8LDk5OdLc3GyOuXLligAQn88nIiJtbW1itVolEAiYY5qamkTTNIlEInN6XcMwBIAYhrGQ8jPS0LUL0n1yt7n4T+2XL4ZuqC6Llpi57qMLugZkGAYAoKioCADg9/sRjUbh8XjMMWvXroXb7YbP5wMA+Hw+rF+/Hk6n0xxTVVWFcDiMvr6+GV8nEokgHA7HLTRHvAZEKWzeARSLxXDgwAE88cQTWLduHQAgEAjAZrOhsLAwbqzT6UQgEDDHfDV8pvun+2bS2NgIh8NhLmVlZfMtO/NZLHEPRQQSm1JUDNH9zTuAvF4vLl++jDNnzixmPTNqaGiAYRjmMjg4+MBfM13ZC4phzVlmPo5FxzExekdhRUSzm9cnFOvq6tDa2oqzZ89i1apVZrvL5cLExASGh4fjjoKCwSBcLpc5pru7O25903fJpsd8nd1uh91un0+pS47Fmg3LPUdBPAWj1JTQEZCIoK6uDi0tLejs7MSaNWvi+isqKpCTk4OOjg6zrb+/HwMDA9B1HQCg6zp6e3sRCoXMMe3t7dA0DeXl5QvZFgJgsVjvOQ0jSlUJHQF5vV688847+Mtf/oKCggLzmo3D4UBubi4cDgd27dqF+vp6FBUVQdM07Nu3D7quY+vWrQCAbdu2oby8HM8//zyOHTuGQCCAV199FV6vl0c5i8FqBcAAovSQUAA1NTUBAL7//e/HtZ86dQo/+9nPAADHjx+H1WpFTU0NIpEIqqqq8NZbb5ljs7Ky0Nraij179kDXdSxfvhy1tbU4cuTIwraEANw9Avr6KRhRqrKIpN9vqYXDYTgcDhiGAU3TVJeTUsbDn+NKSyMmx0fNtkee/ncU/UuFwqpoqZnrPsrPgmUYi4WnYJQ+GEAZhhehKZ0wgDKNxcLjH0obDKAMY7HyCIjSBwMo0/AaEKURBlCGme0aUBre7KQlgAGUaRg+lEYYQEsBPw1PKYoBtATww6iUqhhASwD/HhClKgbQEsAvJqRUxQBaAngKRqmKAbQU8BSMUhQDaAngKRilKgbQEiDCIyBKTQygJYBHQJSqGEBLAG/DU6piAGWge74VIzapqBKi+2MAZRirNRt2R0lc25d3biqqhuj+GECZxmKBNSsnrolHQJSqGEAZ6O7fhSZKfXynZiCLlT9WSg98p2YYi8UCWLJUl0E0JwygDMQjIEoXCX0zKqknIhgbG8Pk5MwXlkViiEbj+6LRSRiGMes6CwoKkJXFoyZKPgZQGtq7dy86Oztn7LNagL3/uh5Pbiwz2z7++L/xH97jM47Pzc3FX//6Vzz88MMPpFai+2EApaGhoSHcuHFjxj6LBfjcWI/e0f+HSGw5Hs7tw/j44Kzj8/LyZj2aInrQErpY0NTUhA0bNkDTNGiaBl3X8d5775n94+Pj8Hq9KC4uRn5+PmpqahAMBuPWMTAwgOrqauTl5aGkpAQHDx7kDrCIsrPtGNN24EbkMdyOluF/Rv4/hiedqssimlFCAbRq1SocPXoU
fr8fFy9exJNPPokf/ehH6OvrAwC88sorePfdd9Hc3Iyuri7cvHkTO3bsMJ8/NTWF6upqTExM4JNPPsHbb7+N06dP49ChQ4u7VUuY1ZqN5Voppr8bbFLsGI/lqS2KaDayQCtWrJDf/e53Mjw8LDk5OdLc3Gz2XblyRQCIz+cTEZG2tjaxWq0SCATMMU1NTaJpmkQikTm/pmEYAkAMw1ho+WknFovJs88+KwBmXKzWLHnhuQY58sY5OfxGtxw/2iaNe/9t1vF5eXly7do11ZtFGWau++i8rwFNTU2hubkZY2Nj0HUdfr8f0WgUHo/HHLN27Vq43W74fD5s3boVPp8P69evh9P5j1OCqqoq7NmzB319fXj88ccTquHq1avIz8+f7yakrdHR0Vn7YrEp9F36LwQDVxCJ5cFp+wxDdwKzjhcRXLt2DZFI5EGUSkvU/d6jX5VwAPX29kLXdYyPjyM/Px8tLS0oLy9HT08PbDYbCgsL48Y7nU4EAnd3gEAgEBc+0/3TfbOJRCJxO0g4HAYAGIaxJK8fRaPR+/b7+weB/sE5rUtEMDIyguHh4UWojOiusbGxOY1LOIAee+wx9PT0wDAM/OlPf0JtbS26uroSLjARjY2NOHz48D3tlZWV0DTtgb52qhERrFixYtHWZ7Va8fjjj+ORRx5ZtHUSTR8k/DMJ/8qszWbDN7/5TVRUVKCxsREbN27Er3/9a7hcLkxMTNzzP2kwGITL5QIAuFyue+6KTT+eHjOThoYGGIZhLoODc/vfnYhS24J/Zz8WiyESiaCiogI5OTno6Ogw+/r7+zEwMABd1wEAuq6jt7cXoVDIHNPe3g5N01BeXj7ra9jtdvPW//RCROkvoVOwhoYGPPPMM3C73RgZGcE777yDjz76CB988AEcDgd27dqF+vp6FBUVQdM07Nu3D7quY+vWrQCAbdu2oby8HM8//zyOHTuGQCCAV199FV6vF3a7/YFsIBGlroQCKBQK4YUXXsCtW7fgcDiwYcMGfPDBB3j66acBAMePH4fVakVNTQ0ikQiqqqrw1ltvmc/PyspCa2sr9uzZA13XsXz5ctTW1uLIkSOLu1UZLi8vb9GOAvPy8mDlh1dJEYuIiOoiEhUOh+FwOGAYxpI7HRMR3L59G+Pj44uyPovFApfLhexsfiqHFs9c91G+69KMxWLBQw89pLoMokXBY28iUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkTLbqAuZDRAAA4XBYcSVENJPpfXN6X51NWgbQ0NAQAKCsrExxJUR0PyMjI3A4HLP2p2UAFRUVAQAGBgbuu3EULxwOo6ysDIODg9A0TXU5aYFzNj8igpGREZSWlt53XFoGkNV699KVw+Hgm2IeNE3jvCWIc5a4uRwc8CI0ESnDACIiZdIygOx2O15//XXY7XbVpaQVzlviOGcPlkX+2X0yIqIHJC2PgIgoMzCAiEgZBhARKcMAIiJl0jKATpw4gdWrV2PZsmWorKxEd3e36pKUaWxsxObNm1FQUICSkhJs374d/f39cWPGx8fh9XpRXFyM/Px81NTUIBgMxo0ZGBhAdXU18vLyUFJSgoMHD2JycjKZm6LM0aNHYbFYcODAAbONc5YkkmbOnDkjNptNfv/730tfX5/s3r1bCgsLJRgMqi5NiaqqKjl16pRcvnxZenp65NlnnxW32y2jo6PmmJdfflnKysqko6NDLl68KFu3bpXvfve7Zv/k5KSsW7dOPB6P/O1vf5O2tjZZuXKlNDQ0qNikpOru7pbVq1fLhg0bZP/+/WY75yw50i6AtmzZIl6v13w8NTUlpaWl0tjYqLCq1BEKhQSAdHV1iYjI8PCw5OTk
SHNzsznmypUrAkB8Pp+IiLS1tYnVapVAIGCOaWpqEk3TJBKJJHcDkmhkZEQeffRRaW9vl+9973tmAHHOkietTsEmJibg9/vh8XjMNqvVCo/HA5/Pp7Cy1GEYBoB/fGDX7/cjGo3GzdnatWvhdrvNOfP5fFi/fj2cTqc5pqqqCuFwGH19fUmsPrm8Xi+qq6vj5gbgnCVTWn0Y9fbt25iamor7oQOA0+nE1atXFVWVOmKxGA4cOIAnnngC69atAwAEAgHYbDYUFhbGjXU6nQgEAuaYmeZ0ui8TnTlzBpcuXcKFCxfu6eOcJU9aBRDdn9frxeXLl/Hxxx+rLiWlDQ4OYv/+/Whvb8eyZctUl7OkpdUp2MqVK5GVlXXP3YhgMAiXy6WoqtRQV1eH1tZWfPjhh1i1apXZ7nK5MDExgeHh4bjxX50zl8s145xO92Uav9+PUCiE73znO8jOzkZ2dja6urrw5ptvIjs7G06nk3OWJGkVQDabDRUVFejo6DDbYrEYOjo6oOu6wsrUERHU1dWhpaUFnZ2dWLNmTVx/RUUFcnJy4uasv78fAwMD5pzpuo7e3l6EQiFzTHt7OzRNQ3l5eXI2JImeeuop9Pb2oqenx1w2bdqEnTt3mv/mnCWJ6qvgiTpz5ozY7XY5ffq0fPrpp/LSSy9JYWFh3N2IpWTPnj3icDjko48+klu3bpnLF198YY55+eWXxe12S2dnp1y8eFF0XRdd183+6VvK27Ztk56eHnn//ffloYceWlK3lL96F0yEc5YsaRdAIiK/+c1vxO12i81mky1btsi5c+dUl6QMgBmXU6dOmWO+/PJL2bt3r6xYsULy8vLkxz/+sdy6dStuPZ999pk888wzkpubKytXrpSf//znEo1Gk7w16nw9gDhnycE/x0FEyqTVNSAiyiwMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISJn/A5p6q/Yfd4BiAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "play(True)[-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第9章-策略梯度算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
